{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use %pip (not !pip) so the package is installed into the running kernel's environment.\n",
    "%pip install --upgrade transformers -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['USE_TORCH'] = 'True'  # To use transformers library in TPU\n",
    "os.environ['XLA_USE_BF16'] = 'True'\n",
    "os.environ['PJRT_DEVICE'] = 'TPU'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "papermill": {
     "duration": 13.608021,
     "end_time": "2023-11-04T12:34:21.013846",
     "exception": false,
     "start_time": "2023-11-04T12:34:07.405825",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/tmp/ipykernel_91021/1636890477.py:13: DeprecationWarning: Importing from `torch_xla.experimental.xla_sharded_tensor` will be deprecated after 2.2 release. Please use `torch_xla.distributed.spmd` instead.\n",
      "  from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor\n",
      "/tmp/ipykernel_91021/1636890477.py:14: DeprecationWarning: Importing from `torch_xla.experimental.xla_sharding` will be deprecated after 2.2 release. Please use `torch_xla.distributed.spmd` instead.\n",
      "  from torch_xla.experimental.xla_sharding import Mesh\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import contextlib\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import torch.distributed as dist\n",
    "import torch_xla.core.xla_model as xm\n",
    "import torch_xla.runtime as xr\n",
    "xr.use_spmd()\n",
    "\n",
    "import torch_xla.distributed.spmd as xs\n",
    "# Import from the supported SPMD package; the torch_xla.experimental.* paths are deprecated.\n",
    "from torch_xla.distributed.spmd import Mesh, XLAShardedTensor\n",
    "\n",
    "import torch_xla.distributed.xla_multiprocessing as xmp\n",
    "import torch_xla.distributed.parallel_loader as pl\n",
    "import torch_xla.test.test_utils as test_utils\n",
    "\n",
    "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator\n",
    "from datasets import Dataset, load_dataset, concatenate_datasets\n",
    "from peft import LoraConfig, TaskType, get_peft_model\n",
    "from safetensors.torch import load_file\n",
    "\n",
    "import logging\n",
    "logging.getLogger(\"datasets\").setLevel(logging.WARNING)\n",
    "logging.getLogger(\"transformers\").setLevel(logging.WARNING)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fail fast with a readable message if SPMD mode is not active.\n",
    "assert xr.is_spmd(), \"SPMD mode is not enabled; call xr.use_spmd() first.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<module 'trainer_lib.model_partitioning' from '/home/RoadrunnerX/llama3_pytorch_xla/trainer_lib/model_partitioning.py'>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import importlib\n",
    "import sys\n",
    "\n",
    "# Add the working directory to the import path so `trainer_lib` resolves, then import\n",
    "# the partitioning helpers and reload them so edits made on disk are picked up.\n",
    "sys.path.append('')\n",
    "model_partitioning = importlib.reload(\n",
    "    importlib.import_module('trainer_lib.model_partitioning')\n",
    ")\n",
    "model_partitioning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "Please provide your HUGGINGFACE_TOKEN:  [REDACTED]\n"
     ]
    }
   ],
   "source": [
    "import getpass\n",
    "\n",
    "# This notebook can be used to train any of the 7B, 8B models. Check out the 80B notebook to train bigger model.\n",
    "supported_models = [\n",
    "    \"meta-llama/Meta-Llama-3-8B\",\n",
    "    \"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
    "    \"meta-llama/Meta-Llama-3.1-8B\",\n",
    "    \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
    "    \"meta-llama/Llama-2-7b-hf\",\n",
    "    \"meta-llama/Llama-2-13b-hf\",\n",
    "    \"TinyLlama/TinyLlama-1.1B-step-50K-105b\",\n",
    "]\n",
    "\n",
    "# Select a supported model from above list to use!\n",
    "MODEL_NAME = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
    "\n",
    "# Read the token without echoing it, so it is never stored in the notebook's saved output.\n",
    "# If the HF_TOKEN environment variable is set it is used instead of prompting.\n",
    "HUGGINGFACE_TOKEN = os.environ.get(\"HF_TOKEN\") or getpass.getpass(\n",
    "    \"Please provide your HUGGINGFACE_TOKEN: \"\n",
    ")  # YOUR_HF_TOKEN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Configure LoRA for your model\n",
    "Use the code below to configure LoRA for your model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_lora(*, model, lora_rank=None, lora_alpha=None, lora_dropout=None):\n",
    "    \"\"\"Applies LoRA configuration to the model.\"\"\"\n",
    "    peft_config = LoraConfig(\n",
    "        task_type=TaskType.CAUSAL_LM,\n",
    "        inference_mode=False,\n",
    "        r=8 if not lora_rank else lora_rank,\n",
    "        lora_alpha=32 if not lora_alpha else lora_alpha,\n",
    "        lora_dropout=0.1 if not lora_dropout else lora_dropout,\n",
    "    )\n",
    "    model = get_peft_model(model, peft_config)\n",
    "    model.print_trainable_parameters()\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_model(*, model_name, hugging_face_token, lora_config=None):\n",
    "    \"\"\"Downloads the tokenizer and causal-LM weights for `model_name`.\n",
    "\n",
    "    Args:\n",
    "        model_name: Hugging Face Hub model id.\n",
    "        hugging_face_token: Access token passed to every `from_pretrained` call.\n",
    "        lora_config: Optional dict of keyword arguments (`lora_rank`,\n",
    "            `lora_alpha`, `lora_dropout`) forwarded to `apply_lora`. When None\n",
    "            (the default) the model is returned without LoRA adapters.\n",
    "\n",
    "    Returns:\n",
    "        A `(model, tokenizer)` tuple.\n",
    "    \"\"\"\n",
    "    config = AutoConfig.from_pretrained(model_name, token=hugging_face_token)\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hugging_face_token)\n",
    "\n",
    "    # Fall back to the EOS token for padding when the tokenizer defines no pad token.\n",
    "    if not tokenizer.pad_token:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "        config.pad_token_id = tokenizer.pad_token_id\n",
    "\n",
    "    model = AutoModelForCausalLM.from_pretrained(\n",
    "        model_name,\n",
    "        # Pass the config explicitly so the pad_token_id set above reaches the model.\n",
    "        config=config,\n",
    "        token=hugging_face_token,\n",
    "        low_cpu_mem_usage=True,\n",
    "    )\n",
    "\n",
    "    if lora_config is not None:\n",
    "        model = apply_lora(model=model, **lora_config)\n",
    "\n",
    "    return model, tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "papermill": {
     "duration": 0.547357,
     "end_time": "2023-11-04T12:34:26.034497",
     "exception": false,
     "start_time": "2023-11-04T12:34:25.48714",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def apply_spmd(*, model, mesh):\n",
    "    # Apply on layers within model.\n",
    "    model_partitioning_util.partition_model(model, mesh)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Configure dataset pipeline for your model\n",
    "\n",
    "For this project, we're utilizing the refined **Alpaca dataset**, curated by yahma. This dataset is a carefully filtered selection of 52,000 entries from the original Alpaca collection. Feel free to substitute this section with your own data preparation code if you prefer.\n",
    "\n",
    "It's crucial to include the EOS_TOKEN (End of Sequence Token) in your tokenized output. Failing to do so may result in endless generation loops."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset(*, tokenizer, batch_size=1, seq_length=32, max_examples=None, seed=99):\n",
    "    \"\"\"Builds train/test DataLoaders over the `yahma/alpaca-cleaned` dataset.\n",
    "\n",
    "    Each row is rendered with the Alpaca prompt template, terminated with the\n",
    "    tokenizer's EOS token, and tokenized to exactly `seq_length` tokens\n",
    "    (truncated or padded).\n",
    "\n",
    "    Args:\n",
    "        tokenizer: Tokenizer providing `eos_token` and a pad token.\n",
    "        batch_size: Batch size used by both DataLoaders.\n",
    "        seq_length: Number of tokens per example.\n",
    "        max_examples: If truthy, only the first `max_examples` rows are used.\n",
    "        seed: Seed for the train/test split, so the split is reproducible.\n",
    "\n",
    "    Returns:\n",
    "        A `(train_dataloader, test_dataloader)` tuple. Batches contain\n",
    "        `input_ids`, `attention_mask` and `labels`, each `seq_length` long.\n",
    "    \"\"\"\n",
    "    # Define Alpaca prompt template. It is assembled from explicit pieces so the\n",
    "    # indentation of this function does not leak into the prompt text.\n",
    "    alpaca_prompt = (\n",
    "        \"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\\n\"\n",
    "        \"\\n\"\n",
    "        \"### Instruction: {}\\n\"\n",
    "        \"\\n\"\n",
    "        \"### Input: {}\\n\"\n",
    "        \"\\n\"\n",
    "        \"### Response: {}\"\n",
    "    )\n",
    "\n",
    "    EOS_TOKEN = tokenizer.eos_token\n",
    "\n",
    "    # Define formatting function.\n",
    "    def _format_prompts(examples):\n",
    "        texts = [\n",
    "            alpaca_prompt.format(instruction, input_text, output) + EOS_TOKEN\n",
    "            for instruction, input_text, output in zip(\n",
    "                examples[\"instruction\"], examples[\"input\"], examples[\"output\"]\n",
    "            )\n",
    "        ]\n",
    "        return {\"text\": texts}\n",
    "\n",
    "    # Tokenize the dataset.\n",
    "    def _tokenize(examples):\n",
    "        # The causal-LM loss shifts the labels internally, so labels are an unshifted\n",
    "        # copy of input_ids. Padding positions are set to -100 so the loss ignores\n",
    "        # them; the attention mask is used (not the pad id) because the pad token may\n",
    "        # be the EOS token, which the model still has to learn to emit.\n",
    "        tokenized = tokenizer(examples[\"text\"], truncation=True, padding=\"max_length\", max_length=seq_length)\n",
    "        tokenized[\"labels\"] = [\n",
    "            [token_id if keep else -100 for token_id, keep in zip(ids, mask)]\n",
    "            for ids, mask in zip(tokenized[\"input_ids\"], tokenized[\"attention_mask\"])\n",
    "        ]\n",
    "        return tokenized\n",
    "\n",
    "    # Load and preprocess the dataset.\n",
    "    dataset = load_dataset(\"yahma/alpaca-cleaned\", split=\"train\")\n",
    "    if max_examples:\n",
    "        dataset = dataset.select(range(max_examples))\n",
    "    dataset = dataset.map(_format_prompts, batched=True)\n",
    "\n",
    "    # Create train and test dataset.\n",
    "    ds = dataset.train_test_split(test_size=0.15, seed=seed)\n",
    "    ds['train'] = ds['train'].map(_tokenize, batched=True, remove_columns=dataset.column_names)\n",
    "    ds['test'] = ds['test'].map(_tokenize, batched=True, remove_columns=dataset.column_names)\n",
    "\n",
    "    # Create DataLoader\n",
    "    train_dataloader = torch.utils.data.DataLoader(\n",
    "        ds['train'],\n",
    "        shuffle=True,\n",
    "        batch_size=batch_size,\n",
    "        collate_fn=default_data_collator,\n",
    "    )\n",
    "\n",
    "    # Evaluation does not benefit from shuffling; keep the order deterministic.\n",
    "    test_dataloader = torch.utils.data.DataLoader(\n",
    "        ds['test'],\n",
    "        shuffle=False,\n",
    "        batch_size=batch_size,\n",
    "        collate_fn=default_data_collator,\n",
    "    )\n",
    "\n",
    "    return train_dataloader, test_dataloader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the model\n",
    "\n",
    "Now let's train the model. We use PyTorch/XLA's SPMD sharding with an `fsdp` mesh axis (Fully Sharded Data Parallel style) to distribute the model across the 8 TPU cores available on a TPU v3-8. This approach allows for efficient training on TPU hardware. We also use PyTorch/XLA's `MpDeviceLoader` to load data onto the TPU cores efficiently.\n",
    "\n",
    "**NOTE:** The **first step of training will be slow**, because XLA first has to compile the computational graph. Once compilation is complete, subsequent steps run much faster using the compiled and cached graph, leveraging the full power of all the TPU cores.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_training_update(step,\n",
    "                          loss,\n",
    "                          epoch=None,\n",
    "                         ):\n",
    "    \"\"\"Prints the training metrics at a given step.\"\"\"\n",
    "    if xm.is_master_ordinal():  # Only print on the master device\n",
    "        update_data = [\n",
    "            'Training',\n",
    "            f'Epoch={epoch}' if epoch is not None else 0,\n",
    "            f'Step={step}',\n",
    "            f'Loss={loss:.5f}',\n",
    "        ]\n",
    "        print(' | '.join(item for item in update_data if item), flush=True)\n",
    "        print()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.07it/s]\n"
     ]
    }
   ],
   "source": [
    "# Load the tokenizer and model weights for the model chosen in the configuration cell.\n",
    "model, tokenizer = init_model(model_name=MODEL_NAME, hugging_face_token=HUGGINGFACE_TOKEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _print_step_closure(step, loss, epoch):\n",
    "    \"\"\"Step closure: converts the loss tensor to a float and prints it.\"\"\"\n",
    "    print_training_update(step, loss.item(), epoch)\n",
    "\n",
    "\n",
    "def _evaluate(*, model, dataloader, mesh):\n",
    "    \"\"\"Returns the mean loss of `model` over `dataloader`, or None if it yields no batch.\"\"\"\n",
    "    model.eval()\n",
    "    total_loss, num_batches = 0.0, 0\n",
    "    with torch.no_grad():\n",
    "        for batch in dataloader:\n",
    "            input_ids, attention_mask, labels = (\n",
    "                batch[\"input_ids\"],\n",
    "                batch[\"attention_mask\"],\n",
    "                batch[\"labels\"],\n",
    "            )\n",
    "            xs.mark_sharding(input_ids, mesh, (0, 1))\n",
    "            xs.mark_sharding(attention_mask, mesh, (0, 1))\n",
    "            xs.mark_sharding(labels, mesh, (0, 1))\n",
    "\n",
    "            output = model(\n",
    "                input_ids=input_ids, attention_mask=attention_mask, labels=labels\n",
    "            )\n",
    "            total_loss += output.loss.item()\n",
    "            num_batches += 1\n",
    "    return total_loss / num_batches if num_batches else None\n",
    "\n",
    "\n",
    "def train(index):\n",
    "    \"\"\"Runs the SPMD training loop on the module-level `model`.\n",
    "\n",
    "    Reads the globals `model`, `tokenizer` and `trainer_config`; the global\n",
    "    `model` is rebound to its copy on the XLA device. If `trainer_config` has\n",
    "    a truthy `run_eval` attribute, the test split is evaluated after each epoch.\n",
    "\n",
    "    Args:\n",
    "        index: Process index passed in by `xmp.spawn`; not used.\n",
    "\n",
    "    Returns:\n",
    "        Dict with the `device` ordinal and the last training `loss` as a float\n",
    "        (None when no training step ran).\n",
    "    \"\"\"\n",
    "    global model, tokenizer, trainer_config\n",
    "\n",
    "    print(trainer_config)\n",
    "\n",
    "    torch.manual_seed(99)\n",
    "    device = xm.xla_device()\n",
    "    model = model.to(device)\n",
    "\n",
    "    # Create a mesh for the model partitioning: all devices on the \"fsdp\" axis.\n",
    "    num_devices = xr.global_runtime_device_count()\n",
    "    mesh_shape = (1, num_devices, 1)\n",
    "    device_ids = np.array(range(num_devices))\n",
    "    mesh = Mesh(device_ids, mesh_shape, (\"dp\", \"fsdp\", \"mp\"))\n",
    "\n",
    "    # Partition the model using SPMD.\n",
    "    model_partitioning.partition_model(model=model, mesh=mesh)\n",
    "\n",
    "    # Configure the training loop.\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=trainer_config.learning_rate)\n",
    "\n",
    "    train_dataloader, test_dataloader = get_dataset(\n",
    "        tokenizer=tokenizer,\n",
    "        batch_size=trainer_config.batch_size,\n",
    "        seq_length=trainer_config.seq_length,\n",
    "        max_examples=trainer_config.max_examples,\n",
    "    )\n",
    "    train_dataloader = pl.MpDeviceLoader(train_dataloader, device)\n",
    "    test_dataloader = pl.MpDeviceLoader(test_dataloader, device)\n",
    "\n",
    "    max_steps = trainer_config.max_steps\n",
    "    last_loss = None  # Stays None if no training step runs.\n",
    "    should_break = False\n",
    "    for epoch in range(trainer_config.epochs):\n",
    "        xm.master_print(f\"Epoch {epoch} train begin {test_utils.now()}\")\n",
    "\n",
    "        model.train()\n",
    "        for step, batch in enumerate(train_dataloader):\n",
    "            # `step` is zero-based, so `>=` runs exactly `max_steps` steps.\n",
    "            if max_steps is not None and step >= max_steps:\n",
    "                should_break = True\n",
    "                break\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            input_ids, attention_mask, labels = (\n",
    "                batch[\"input_ids\"],\n",
    "                batch[\"attention_mask\"],\n",
    "                batch[\"labels\"],\n",
    "            )\n",
    "            xs.mark_sharding(input_ids, mesh, (0, 1))\n",
    "            xs.mark_sharding(attention_mask, mesh, (0, 1))\n",
    "            xs.mark_sharding(labels, mesh, (0, 1))\n",
    "\n",
    "            output = model(\n",
    "                input_ids=input_ids, attention_mask=attention_mask, labels=labels\n",
    "            )\n",
    "            loss = output.loss\n",
    "            loss.backward()\n",
    "\n",
    "            optimizer.step()\n",
    "            last_loss = loss.detach()\n",
    "\n",
    "            if step % trainer_config.print_every_n_steps == 0:\n",
    "                # Register the closure before mark_step so it runs at the end of this\n",
    "                # step; the loss is read inside the closure instead of forcing an\n",
    "                # extra blocking .item() in the loop body.\n",
    "                xm.add_step_closure(\n",
    "                    _print_step_closure,\n",
    "                    args=(step, last_loss, epoch)\n",
    "                )\n",
    "            xm.mark_step()\n",
    "\n",
    "        # Optional evaluation: set `trainer_config.run_eval = True` to enable.\n",
    "        if getattr(trainer_config, \"run_eval\", False):\n",
    "            avg_eval_loss = _evaluate(model=model, dataloader=test_dataloader, mesh=mesh)\n",
    "            if avg_eval_loss is not None:\n",
    "                xm.master_print(f\"Eval loss: {avg_eval_loss:.4f}\")\n",
    "\n",
    "        if should_break:\n",
    "            break\n",
    "\n",
    "    result = {\n",
    "        'device': xm.get_ordinal(),\n",
    "        'loss': last_loss.item() if last_loss is not None else None,\n",
    "    }\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class TrainerConfig:    \n",
    "    epochs: int = 1\n",
    "    batch_size: int = 32\n",
    "    seq_length: int = 64\n",
    "    \n",
    "    learning_rate: float = 1e-4\n",
    "\n",
    "    max_steps: int | None = 100\n",
    "    max_examples: int| None = None\n",
    "    \n",
    "    print_every_n_steps: int = 5\n",
    "    \n",
    "    lora_rank: int = 8\n",
    "    lora_alpha: int = 32\n",
    "    lora_dropout: float = 0.1\n",
    "    \n",
    "trainer_config = TrainerConfig()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Time a single run by hand. `%%timeit` must not be used here: it re-executes the\n",
    "# cell several times, which would repeat the whole training run.\n",
    "start_time = time.perf_counter()\n",
    "try:\n",
    "    xmp.spawn(train, args=(), start_method=\"fork\")\n",
    "except Exception as e:\n",
    "    # An error is expected when obtaining results from multiple TPU chips after starting\n",
    "    # distributed training from a notebook, so keep going -- but report what was raised\n",
    "    # so that genuine training failures are not hidden.\n",
    "    print(f\"xmp.spawn raised {type(e).__name__}: {e}\")\n",
    "\n",
    "elapsed_time = time.perf_counter() - start_time\n",
    "print(f\"Execution time: {elapsed_time:.4f} seconds\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Export the model to HuggingFace Hub\n",
    "Uncomment the code in the following cell to push the model to HuggingFace Hub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 388.410448,
     "end_time": "2023-11-04T13:21:54.038795",
     "exception": false,
     "start_time": "2023-11-04T13:15:25.628347",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "HUGGINGFACE_USERNAME = input(\"Please provide your HUGGINGFACE_USERNAME: \")\n",
    "\n",
    "model = model.cpu()\n",
    "# `merge_and_unload` is only available when the model was wrapped with LoRA adapters.\n",
    "# Without LoRA the model is exported as-is instead of raising AttributeError.\n",
    "merged_model = model.merge_and_unload() if hasattr(model, \"merge_and_unload\") else model\n",
    "\n",
    "print(\"Uncomment below code if you want to upload to HF.\")\n",
    "# print(\"Uploading to HF...\")\n",
    "# merged_model.push_to_hub(\n",
    "#     f\"{HUGGINGFACE_USERNAME}/felafax-llama3-finetuned\",  # repo name\n",
    "#     tokenizer=tokenizer,\n",
    "#     private=False,\n",
    "#     create_pr=False,\n",
    "#     max_shard_size=\"2GB\",\n",
    "#     token=HUGGINGFACE_TOKEN,\n",
    "# )"
   ]
  }
 ],
 "metadata": {
  "kaggle": {
   "accelerator": "tpu1vmV38",
   "dataSources": [
    {
     "databundleVersionId": 7516023,
     "sourceId": 61542,
     "sourceType": "competition"
    },
    {
     "datasetId": 3555678,
     "isSourceIdPinned": true,
     "sourceId": 6196932,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 3863727,
     "sourceId": 6703755,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 3936750,
     "sourceId": 6847931,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 3946973,
     "sourceId": 6867914,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 3937441,
     "sourceId": 6868189,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 3949797,
     "sourceId": 6873567,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 3942644,
     "sourceId": 6890527,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 3937250,
     "sourceId": 7017419,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 3944051,
     "sourceId": 7060310,
     "sourceType": "datasetVersion"
    }
   ],
   "dockerImageVersionId": 30529,
   "isGpuEnabled": false,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
